{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zX6H255Agj68"
      },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/google-research/text-to-text-transfer-transfrormer/blob/main/notebooks/t5-deploy.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Yo_HOomXe1f2"
      },
      "source": [
        "##### Copyright 2020 The T5 Authors\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Rz9fAJ8PexKB"
      },
      "outputs": [],
      "source": [
        "# Copyright 2020 The T5 Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "geoZEiaLdGfR"
      },
      "source": [
        "# T5 SavedModel Export and Inference\n",
        "\n",
        "This notebook guides you through the process of exporting a [T5](https://github.com/google-research/text-to-text-transformer) `SavedModel` for inference. It uses the fine-tuned checkpoints in the [T5 Closed Book QA](https://github.com/google-research/google-research/tree/main/t5_closed_book_qa) repository for the [Natural Questions](https://ai.google.com/research/NaturalQuestions/) task as an example, but the same process will work for any model trained with the `t5` library.\n",
        "\n",
        "For more general usage of the `t5` library, please see the main [github repo](https://github.com/google-research/text-to-text-transformer) and fine-tuning [colab notebook](https://tiny.cc/t5-colab).\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WtS5hODBKtR_"
      },
      "source": [
        "## Install T5 Library"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "UHCx-R4M-D0a"
      },
      "outputs": [],
      "source": [
        "!pip install -q t5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IFuyCiHpLCh7"
      },
      "source": [
        "## Export `SavedModel` to local storage\n",
        "\n",
        "NOTE: This will take a while for XL and XXL."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "-Y7QSepo9a8H"
      },
      "outputs": [],
      "source": [
        "MODEL = \"small_ssm_nq\" #@param[\"small_ssm_nq\", \"t5.1.1.xl_ssm_nq\", \"t5.1.1.xxl_ssm_nq\"]\n",
        "\n",
        "import os\n",
        "\n",
        "saved_model_dir = f\"/content/{MODEL}\"\n",
        "\n",
        "!t5_mesh_transformer \\\n",
        "  --model_dir=\"gs://t5-data/pretrained_models/cbqa/{MODEL}\" \\\n",
        "  --use_model_api \\\n",
        "  --mode=\"export_predict\" \\\n",
        "  --export_dir=\"{saved_model_dir}\"\n",
        "\n",
        "saved_model_path = os.path.join(saved_model_dir, max(os.listdir(saved_model_dir)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JUH5BkcYK3At"
      },
      "source": [
        "## Load `SavedModel` and create helper functions for inference\n",
        "\n",
        "NOTE: This will take a while for XL and XXL."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xiBSnGuu-em0"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_text  # Required to run exported model.\n",
        "\n",
        "model = tf.saved_model.load(saved_model_path, [\"serve\"])\n",
        "\n",
        "def predict_fn(x):\n",
        " return model.signatures['serving_default'](tf.constant(x))['outputs'].numpy()\n",
        "\n",
        "def answer(question):\n",
        "  return predict_fn([question])[0].decode('utf-8')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HSlRnhz7VpTu"
      },
      "source": [
        "## Ask some questions\n",
        "\n",
        "We must prefix each question with the `nq question:` prompt since T5 is a multitask model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CE1bO4hw--Zh"
      },
      "outputs": [],
      "source": [
        "for question in [\"nq question: where is google's headquarters\",\n",
        "                 \"nq question: what is the most populous country in the world\",\n",
        "                 \"nq question: name a member of the beatles\",\n",
        "                 \"nq question: how many teeth do humans have\"]:\n",
        "    print(answer(question))"
      ]
    },
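    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The prefixing done by hand in the cell above could also be wrapped in a small helper. This is only a minimal sketch: `NQ_PREFIX` and `with_prefix` are hypothetical names that are not part of the `t5` library, and the prefix string is simply copied from the example questions above.\n",
        "\n",
        "```python\n",
        "# Hypothetical helper, not part of the t5 library.\n",
        "NQ_PREFIX = 'nq question: '  # Copied from the example questions above.\n",
        "\n",
        "def with_prefix(question, prefix=NQ_PREFIX):\n",
        "  # Prepend the task prefix unless the question already starts with it.\n",
        "  question = question.strip()\n",
        "  if question.startswith(prefix):\n",
        "    return question\n",
        "  return prefix + question\n",
        "\n",
        "print(with_prefix('name a member of the beatles'))\n",
        "```\n",
        "\n",
        "With this in place, `answer(with_prefix('name a member of the beatles'))` would send the same input as the third question above."
      ]
    },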
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BlTOrC7iaCZD"
      },
      "source": [
        "## Package in Docker image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ndK6zIryaKTX"
      },
      "source": [
        "```bash\n",
        "MODEL_NAME=model-name\n",
        "SAVED_MODEL_PATH=/path/to/export/dir\n",
        "\n",
        "# Download the TensorFlow Serving Docker image and repo:\n",
        "docker pull tensorflow/serving:nightly\n",
        "\n",
        "# First, run a serving image as a daemon:\n",
        "docker run -d --name serving_base tensorflow/serving:nightly\n",
        "\n",
        "# Next, copy your `SavedModel` to the container's model folder:\n",
        "docker cp $SAVED_MODEL_PATH serving_base:/models/$MODEL_NAME\n",
        "\n",
        "# Now, commit the container that's serving your model:\n",
        "docker commit --change \"ENV MODEL_NAME ${MODEL_NAME}\" serving_base $MODEL_NAME\n",
        "\n",
        "# Finally, save the image to a tar file:\n",
        "docker save $MODEL_NAME -o $MODEL_NAME.tar\n",
        "\n",
        "# You can now stop `serving_base`:\n",
        "docker kill serving_base\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WYQiZke3nXD_"
      },
      "source": [
        "```bash\n",
        "docker run -t --rm -p 8501:8501 --name $MODEL_NAME-server $MODEL_NAME \u0026\n",
        "\n",
        "curl -d '{\"inputs\": [\"nq question: what is the most populous country?\"]}' \\\n",
        "    -X POST http://localhost:8501/v1/models/$MODEL_NAME:predict\n",
        "\n",
        "docker stop $MODEL_NAME-server\n",
        "```"
      ]
    }
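    ,
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `curl` request above can also be issued from Python. The following is a minimal sketch using only the standard library; `build_predict_request` and `query` are hypothetical helper names, and the host and port are assumed to match the `docker run` command above. The layout of the parsed response depends on the exported signature, so it is returned as-is.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import urllib.request\n",
        "\n",
        "def build_predict_request(model_name, questions, host='localhost', port=8501):\n",
        "  # Build a POST request for the TensorFlow Serving REST predict endpoint.\n",
        "  url = f'http://{host}:{port}/v1/models/{model_name}:predict'\n",
        "  body = json.dumps({'inputs': list(questions)}).encode('utf-8')\n",
        "  headers = {'Content-Type': 'application/json'}\n",
        "  return urllib.request.Request(url, data=body, headers=headers, method='POST')\n",
        "\n",
        "def query(model_name, questions):\n",
        "  # Send the request and parse the JSON response (requires a running server).\n",
        "  with urllib.request.urlopen(build_predict_request(model_name, questions)) as response:\n",
        "    return json.loads(response.read().decode('utf-8'))\n",
        "```\n",
        "\n",
        "For example, `query('model-name', ['nq question: what is the most populous country?'])` mirrors the `curl` call, with `model-name` standing in for `$MODEL_NAME`. The server must still be running, so this has to happen before `docker stop`."
      ]
    }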
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Yo_HOomXe1f2"
      ],
      "name": "t5-deploy",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
